import argparse
import functools
import json
import os
import warnings
from collections import defaultdict

import nltk
import spacy
from azfuse import File
from nltk.stem import WordNetLemmatizer
from tqdm import tqdm

nltk.download('punkt')
nltk.download('averaged_perceptron_tagger')
nltk.download('wordnet')
nlp = spacy.load("en_core_web_lg")
warnings.filterwarnings("ignore", category=UserWarning)

def safe_divide(a, b):
    """Return ``a / b``, or 0 when the divisor *b* equals zero.

    Used for every ratio in the report so empty categories score 0 instead
    of raising ZeroDivisionError.
    """
    if b == 0:
        return 0
    return a / b

def get_args(argv=None):
    """Parse the command-line options of the evaluation script.

    Args:
        argv: optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, which is how the script is run.

    Returns:
        argparse.Namespace holding the parsed options.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--annotation_folder", type=str, default='/<DATA_FOLDER>/amber/annotations/')
    parser.add_argument("--word_association", type=str, default='relation.json')
    parser.add_argument("--safe_words", type=str, default='safe_words.txt')
    # Required: main() needs a path here. Without one it would only fail later,
    # when it calls a string method on None.
    parser.add_argument("--inference_data", type=str, required=True)
    parser.add_argument("--annotation", type=str, default='query_debug.jsonl')
    parser.add_argument("--metrics", type=str, default='metrics.txt')
    parser.add_argument("--similarity_score", type=float, default=0.8)
    # Required: without it main() selects no dimension and writes an empty
    # output file.
    parser.add_argument('--evaluation_type', choices=['a', 'g', 'd', 'de', 'da', 'dr'], required=True, help='a: all tasks and dimensions    g: generative task    d: descriminative task    de, da, dr: existence, attribute, relation')
    parser.add_argument('--output_file', type=str, default='output.json')
    return parser.parse_args(argv)


@functools.lru_cache(maxsize=65536)
def _word_doc(word):
    """Return the spaCy ``Doc`` for *word*, memoised (bounded) across calls.

    The scoring loop compares each candidate noun against every truth and
    hallucination word, so the same short strings are parsed repeatedly.
    Caching the parse avoids re-running the pipeline for every comparison.
    """
    return nlp(word)


def check_synonyms_word(word1, word2, similarity_score):
    """Return True when spaCy similarity of *word1* and *word2* exceeds a threshold.

    Args:
        word1, word2: words to compare (each parsed with the module-level ``nlp``).
        similarity_score: threshold; the comparison is strict (``>``).

    Returns:
        bool: ``doc1.similarity(doc2) > similarity_score``.
    """
    return _word_doc(word1).similarity(_word_doc(word2)) > similarity_score


def extract_nouns(text):
    """Return the WordNet lemma of every noun token in *text*.

    The text is tokenised and POS-tagged with NLTK. A token counts as a noun
    when its tag starts with 'NN'. Lemmas come back in token order, and
    duplicates are kept.
    """
    lemmatize = WordNetLemmatizer().lemmatize
    tagged_tokens = nltk.pos_tag(nltk.word_tokenize(text))
    return [lemmatize(token) for token, tag in tagged_tokens if tag.startswith('NN')]


def init(cli_args=None):
    """Load the initial metric counters from the metrics file.

    Each line of the file is split on ``=``. Lines that yield exactly two parts
    become ``metrics[name] = eval(value)``, with both sides stripped. All other
    lines are ignored.

    Args:
        cli_args: parsed options providing ``metrics`` and ``annotation_folder``.
            When omitted, the module-level ``args`` (bound under ``__main__``)
            is used, which matches the original zero-argument call.

    Returns:
        dict mapping each metric name to its evaluated initial value.
    """
    if cli_args is None:
        cli_args = args  # module-level global bound in the __main__ guard
    if cli_args.annotation_folder not in cli_args.metrics:
        # The namespace is updated in place with the resolved path.
        cli_args.metrics = os.path.join(cli_args.annotation_folder, cli_args.metrics)

    metrics = {}
    with File.open(cli_args.metrics, "r") as file:
        for line in file:
            parts = line.strip().split('=')
            if len(parts) == 2:
                # NOTE(review): eval() executes arbitrary expressions from the
                # metrics file. That is fine for a trusted local config, but
                # ast.literal_eval would be safer if the file only holds literals.
                metrics[parts[0].strip()] = eval(parts[1].strip())
    return metrics


# Maps a discriminative question type to the prefix of its per-type counters in
# ``metrics``. Any other non-generative type is counted under 'asso' (relation).
_DISCRIMINATIVE_PREFIXES = {
    'discriminative-attribute-state': 'as',
    'discriminative-attribute-number': 'an',
    'discriminative-attribute-action': 'aa',
    'discriminative-hallucination': 'ha',
}

# Score dimensions switched on by each grouped --evaluation_type value. Any other
# value enables only the dimension of the same name.
_DIMENSIONS_BY_TYPE = {
    'a': ('g', 'de', 'da', 'dr'),
    'g': ('g',),
    'd': ('de', 'da', 'dr'),
}


def _resolve_in_folder(folder, path):
    """Return *path* joined onto *folder*, unless *folder* already occurs in *path*."""
    return path if folder in path else os.path.join(folder, path)


def _load_records(path):
    """Load *path* as one JSON document if it ends in 'json', else as JSON Lines.

    Blank lines in a JSON Lines file are skipped.
    """
    with File.open(path, 'r') as fp:
        if path.endswith("json"):
            return json.load(fp)
        return [json.loads(line) for line in fp if line.strip()]


def _build_word_index(objects, association):
    """Flatten annotated *objects* and their associated words into a lookup list.

    Returns ``(words, owners)``. ``words`` holds every word associated with each
    object (in object order), followed by the objects themselves. ``owners[j]``
    is the index in *objects* of the object that ``words[j]`` stands for.
    """
    words, owners = [], []
    for pos, obj in enumerate(objects):
        related = association[obj]
        words.extend(related)
        owners.extend([pos] * len(related))
    words.extend(objects)
    owners.extend(range(len(objects)))
    return words, owners


def _first_synonym(noun, candidates, similarity_score):
    """Return the index of the first candidate similar to *noun*, or None."""
    for j, candidate in enumerate(candidates):
        if check_synonyms_word(noun, candidate, similarity_score):
            return j
    return None


def _score_generative(text, gt, association, hallucination_words,
                      global_safe_words, similarity_score, metrics):
    """Accumulate the CHAIR / Cover / Hal / Cog counters for one generated text."""
    nouns = [noun for noun in extract_nouns(text) if noun in hallucination_words]
    safe_words, safe_owner = _build_word_index(gt['truth'], association)
    ha_words, ha_owner = _build_word_index(gt['hallu'], association)
    safe_covered = [0] * len(gt['truth'])  # 1 once a truth object is mentioned
    ha_covered = [0] * len(gt['hallu'])    # 1 once an annotated hallucination is mentioned
    hallucinated = [0] * len(nouns)        # 1 for nouns not backed by the truth

    for idx, noun in enumerate(nouns):
        if noun in global_safe_words:
            continue
        if noun in safe_words:
            safe_covered[safe_owner[safe_words.index(noun)]] = 1
            continue
        if noun in ha_words:
            ha_covered[ha_owner[ha_words.index(noun)]] = 1
        # The synonym pass over hallucination words runs even after an exact match.
        j = _first_synonym(noun, ha_words, similarity_score)
        if j is not None:
            ha_covered[ha_owner[j]] = 1
        j = _first_synonym(noun, safe_words, similarity_score)
        if j is not None:
            safe_covered[safe_owner[j]] = 1
            continue
        hallucinated[idx] = 1

    metrics['chair_score'] += sum(hallucinated)
    metrics['chair_num'] += len(hallucinated)
    metrics['safe_cover_score'] += sum(safe_covered)
    metrics['safe_cover_num'] += len(safe_covered)
    metrics['hallu_cover_score'] += sum(ha_covered)
    metrics['hallu_cover_num'] += len(ha_covered)
    if sum(hallucinated) == 0:
        metrics['non_hallu_score'] += 1
    metrics['non_hallu_num'] += 1


def _parse_yes_no(text):
    """Reduce a free-form answer to 'yes' or 'no' using only its first sentence."""
    first_sentence = text.lower().split('.')[0].replace(',', '')
    words = first_sentence.split(' ')
    if 'i don' in first_sentence or 'unanswerable' in words:
        return 'no'
    if 'yes' in words:
        return 'yes'
    if 'no' in words or 'not' in words:
        return 'no'
    return 'yes'


def _score_discriminative(text, gt, metrics):
    """Accumulate the overall and per-type yes/no counters for one answer."""
    truth = gt['truth']
    response = _parse_yes_no(text)
    type_prefix = _DISCRIMINATIVE_PREFIXES.get(gt['type'], 'asso')
    # Each counter is kept twice: overall ('qa_...') and per type ('<prefix>_qa_...').
    for base in ('qa_', type_prefix + '_qa_'):
        metrics[base + 'correct_num'] += 1
        if truth == 'yes':
            if response == 'yes':
                metrics[base + 'correct_score'] += 1
        else:
            # Any ground truth other than 'yes' counts as a negative question.
            metrics[base + 'no_num'] += 1
            if response == 'no':
                metrics[base + 'correct_score'] += 1
                metrics[base + 'no_score'] += 1
        if response == 'no':
            metrics[base + 'ans_no_num'] += 1
            if truth == 'no':
                metrics[base + 'ans_no_score'] += 1


def _percent(score, num):
    """Return score/num as a percentage rounded to one decimal (0 when num is 0)."""
    return round(safe_divide(score, num) * 100, 1)


def _f1(precision, recall, eps=0.0001):
    """Return F1 (percent, one decimal) from percentage *precision* and *recall*.

    *eps* keeps the denominator non-zero when both inputs are 0.
    """
    p, r = precision / 100, recall / 100
    return round(2 * safe_divide(p * r, p + r + eps) * 100, 1)


def _discriminative_scores(metrics, prefixes, eps=0.0001):
    """Pool the yes/no counters under each key prefix and return the score dict.

    Precision is computed over 'no' answers and recall over 'no' ground truths.
    """
    def total(name):
        return sum(metrics[prefix + name] for prefix in prefixes)

    precision = _percent(total('qa_ans_no_score'), total('qa_ans_no_num'))
    recall = _percent(total('qa_no_score'), total('qa_no_num'))
    return {
        'Accuracy': _percent(total('qa_correct_score'), total('qa_correct_num')),
        'Precision': precision,
        'Recall': recall,
        'F1': _f1(precision, recall, eps),
        'Count': total('qa_correct_num'),
    }


def _print_scores(title, scores, end_with_blank=True):
    """Print one discriminative score block in the script's tab-aligned format."""
    print(title)
    print("Accuracy:\t", scores['Accuracy'])
    print("Precision:\t", scores['Precision'])
    print("Recall:\t\t", scores['Recall'])
    if end_with_blank:
        print("F1:\t\t", scores['F1'], "\n")
    else:
        print("F1:\t\t", scores['F1'])


def main(args):
    """Score inference outputs against the annotations and write a JSON summary.

    Steps:
    1. Read the initial counters via init(), the word-association map, the
       global safe-word list, the inference records and the ground-truth records.
    2. Update the counters for each record. Generative records are scored on
       nouns; all other records are treated as yes/no questions.
    3. Print and write to ``args.output_file`` the scores for the dimensions
       selected by ``args.evaluation_type``.

    ``args.word_association`` and ``args.safe_words`` are rewritten in place to
    their resolved paths.
    """
    metrics = init()
    args.word_association = _resolve_in_folder(args.annotation_folder, args.word_association)
    args.safe_words = _resolve_in_folder(args.annotation_folder, args.safe_words)

    with File.open(args.word_association, 'r') as fp:
        association = json.load(fp)
    # Every annotated object and associated word; only nouns in this set are scored.
    hallucination_words = set(association)
    for related in association.values():
        hallucination_words.update(related)

    with File.open(args.safe_words, 'r') as safe_file:
        global_safe_words = {line.split('\n')[0] for line in safe_file}

    selected = _DIMENSIONS_BY_TYPE.get(args.evaluation_type, (args.evaluation_type,))
    dimension = {key: key in selected for key in ('g', 'de', 'da', 'dr')}

    inference_data = _load_records(args.inference_data)
    id2gt = {gt['question_id']: gt for gt in _load_records(args.annotation)}

    for record in tqdm(inference_data):
        gt = id2gt[record['question_id']]
        if gt['type'] == 'generative':
            _score_generative(record['text'], gt, association, hallucination_words,
                              global_safe_words, args.similarity_score, metrics)
        else:
            _score_discriminative(record['text'], gt, metrics)

    output_scores = {}
    if dimension['g']:
        chair = _percent(metrics['chair_score'], metrics['chair_num'])
        cover = _percent(metrics['safe_cover_score'], metrics['safe_cover_num'])
        cog = _percent(metrics['hallu_cover_score'], metrics['hallu_cover_num'])
        # Share of generative responses containing at least one hallucinated noun.
        hal = round(100 - safe_divide(metrics['non_hallu_score'], metrics['non_hallu_num']) * 100, 1)
        print("Generative Task:")
        print("CHAIR:\t\t", chair)
        print("Cover:\t\t", cover)
        print("Hal:\t\t", hal)
        print("Cog:\t\t", cog, "\n")
        # JSON values match the printed labels above.
        output_scores['g'] = {
            'CHAIR': chair,
            'Cover': cover,
            'Hal': hal,
            'Cog': cog,
            'Count_Chair': metrics['chair_num'],
            'Count_Cover': metrics['safe_cover_num'],
            'Count_Hal': metrics['hallu_cover_num'],
            'Count_non_Hal': metrics['non_hallu_num'],
        }

    if dimension['de'] and dimension['da'] and dimension['dr']:
        overall = _discriminative_scores(metrics, ('',))
        _print_scores("Descriminative Task:", overall)
        output_scores['d'] = overall

    if dimension['de']:
        # Existence F1 keeps this script's 0.001 smoothing constant.
        existence = _discriminative_scores(metrics, ('ha_',), eps=0.001)
        _print_scores("Exsitence:", existence)
        output_scores['de'] = existence

    if dimension['da']:
        attribute = _discriminative_scores(metrics, ('as_', 'an_', 'aa_'))
        state = _discriminative_scores(metrics, ('as_',))
        number = _discriminative_scores(metrics, ('an_',))
        action = _discriminative_scores(metrics, ('aa_',))
        _print_scores("Attribute:", attribute)
        _print_scores("State:", state)
        _print_scores("Number:", number)
        _print_scores("Action:", action)
        output_scores['da'] = {
            'Attribute': attribute,
            'State': state,
            'Number': number,
            'Action': action,
        }

    if dimension['dr']:
        relation = _discriminative_scores(metrics, ('asso_',))
        _print_scores("Relation:", relation, end_with_blank=False)
        output_scores['dr'] = relation

    with File.open(args.output_file, 'w') as file:
        json.dump(output_scores, file, indent=4)

if __name__ == "__main__":
    # Bound at module scope on purpose: init() reads this module-level ``args``
    # when main() calls it without arguments.
    args = get_args()
    main(args)
